#!/usr/bin/env python
from datetime import datetime
import getopt
import os
import pwd
import re
import sys

# Runtime options and shared state. These module-level globals are filled in
# from the environment (init_config_with_env) and from the command line
# (the __main__ block) and are read by the functions below.
# is_debug: set when SLURM_TO_DONAU_DEBUG is 'true'; gates log().
is_debug = False
# is_long: set by -l/--long; selects the long report layout.
is_long = False
# has_print_title: flipped by print_title() so the header is printed only once.
has_print_title = False
# log_file: debug log path, assigned by init_log_file_name().
log_file = ""
# djob_path: path of the djob executable, assigned by check_donau_cli_env().
djob_path = ''
# Filters: job ids (-j), names (-n / SQUEUE_NAMES), states (-t / SQUEUE_STATES),
# users (-u / SQUEUE_USERS) and partitions (-p / SQUEUE_PARTITION).
job_ids = []
job_names = []
job_states = []
job_users = []
job_partitions = []
# Error lines collected by check_result_error(); printed after the job table.
job_err_info = []
# Maximum number of characters shown per column; print_job_info() truncates
# longer values. TIME_LIMIT_STR_LEN_LIMIT is not referenced in the code
# visible here.
PARTITION_STR_LEN_LIMIT = 9
NAME_STR_LEN_LIMIT = 8
USER_STR_LEN_LIMIT = 8
STATE_STR_LEN_LIMIT = 8
TIME_STR_LEN_LIMIT = 10
TIME_LIMIT_STR_LEN_LIMIT = 9
# Keys of the per-job info dictionaries built by parse_job_info() and
# print_query_job_info(). NAME, USER and STATE double as header labels.
TOTAL_SECONDS = 'TOTAL_SECONDS'
DISPLAY_ID = 'ID'
DISPLAY_NAME = 'NAME'
DISPLAY_USER = 'USER'
DISPLAY_STATE = 'STATE'
DISPLAY_QUEUE = 'QUEUE'
DISPLAY_TIME_OUT = 'TIME_OUT'
DISPLAY_STATE_REASON_CODE = 'STATE_REASON_CODE'
DISPLAY_EXEC_NODE = 'EXEC_NODE'
DISPLAY_RUN_TIME = 'RUN_TIME'
# Column header labels printed by print_title().
DISPLAY_JOB_ID = 'JOBID'
DISPLAY_PARTITION = 'PARTITION'
DISPLAY_TIME = 'TIME'
DISPLAY_NODES = 'NODES'
# NOTE(review): 'TIME_LIMI' looks deliberately cut to 9 characters, presumably
# to mirror the header Slurm itself prints -- confirm before "fixing" it.
DISPLAY_TIME_LIMIT = 'TIME_LIMI'
DISPLAY_NODELIST = 'NODELIST(REASON)'
DISPLAY_STATE_SHORT = 'ST'
# str.format templates for one table row: the short layout has 8 fields, the
# long layout has 9 (it adds the time limit after the run time).
DISPLAY_SHORT_FORMAT = '{:>13}{:>10}{:>9}{:>9}{:>3}{:>11}{:>7} {}'
DISPLAY_LONG_FORMAT = '{:>13}{:>10}{:>9}{:>9}{:>9}{:>11} {:<10}{:>7} {}'
# strftime pattern for the timestamp printed above the long-format header.
DISPLAY_TIME_FORMAT = '%a %b %d %H:%M:%S %Y'
# Column positions in a whitespace-split line of the plain djob listing.
JOB_QUERY_ID_INDEX = 0
JOB_QUERY_NAME_INDEX = 1
JOB_QUERY_USER_INDEX = 2
JOB_QUERY_STATE_INDEX = 3
JOB_QUERY_QUEUE_INDEX = 4
# Minimum number of columns a listing line needs before parse_job_info() will
# read it (the name looks like a typo of JOB_QUERY_LEN).
JOB_QUERY_IEN = 8
# Column positions in a whitespace-split line of the
# `djob -x -CO "JOB_ID JOB_TIMEOUT EXEC_NODE MAIN_STATE_REASON_CODE MRUN_TIME"`
# output.
JOB_CUSTOM_ID_INDEX = 0
JOB_CUSTOM_TIME_OUT_INDEX = 1
JOB_CUSTOM_EXEC_NODE_INDEX = 2
JOB_CUSTOM_STATE_REASON_INDEX = 3
JOB_CUSTOM_RUN_TIME_INDEX = 4
# Threshold below which print_query_job_info() skips a detail line.
JOB_CUSTOM_VALID_LEN = 5


def get_help_info():
    """Return the multi-line text shown for ``squeue --help``.

    The text has no trailing newline; the caller prints it as is.
    """
    help_lines = (
        'Usage: squeue [OPTIONS]',
        '  -j, --job=job_id(s)             comma separated list of jobs IDs to view, default is all',
        '  -l, --long                      long report',
        '  -n, --name=job_name(s)          comma separated list of job names to view',
        '  -p, --partition=partition(s)    comma separated list of partitions to view, default is all',
        # the trailing blank after "view," is part of the original text
        '  -t, --states=states             comma separated list of states to view, ',
        '                                  only support pending, running and suspended',
        '  -u, --user=user_name(s)         comma separated list of users to view,',
        '                                  this option only takes effect while user is administrator,',
        "                                  normal user can not query other user's job",
        '',
        'Help options:',
        '  --help                          show this help message',
        '  --usage                         display a brief summary',
    )
    return '\n'.join(help_lines)


def get_usage_info():
    """Return the two-line brief usage summary shown for ``squeue --usage``."""
    first_line = 'Usage: squeue [--job jobid] [-n name] [-p partitions]'
    # the continuation is indented so it lines up under the first bracket
    second_line = '              [-t states] [-u user_name] [--usage] [-l]'
    return first_line + '\n' + second_line


def log(content):
    """Append one line to the debug log when debug mode is enabled.

    Does nothing unless the module-level ``is_debug`` flag is set. The file
    named by ``log_file`` is opened in append mode on every call; the ``with``
    block guarantees it is closed even if the write raises.

    :param content: text to record; a newline is appended.
    """
    if not is_debug:
        return
    with open(log_file, 'a+') as log_handle:
        log_handle.write(content + '\n')


def init_log_file_name():
    """Set ``log_file`` to ``/tmp/squeue.<uid>.<pid>`` for the current process."""
    global log_file
    name_parts = ['squeue', str(os.getuid()), str(os.getpid())]
    log_file = os.path.join('/tmp/', '.'.join(name_parts))


# init options from environment
def init_config_with_env():
    """Load option defaults from environment variables.

    * ``SLURM_TO_DONAU_DEBUG`` -- when it equals ``true`` (case-insensitive),
      debug logging is switched on and the log file name is initialised.
    * ``SQUEUE_NAMES``, ``SQUEUE_STATES``, ``SQUEUE_USERS``,
      ``SQUEUE_PARTITION`` -- comma separated lists stored in ``job_names``,
      ``job_states``, ``job_users`` and ``job_partitions``.

    Variables that are not set leave the corresponding global untouched.
    """
    global is_debug, job_names, job_states, job_users, job_partitions

    # debug must be resolved first so the log() calls below can take effect
    debug_env = os.getenv('SLURM_TO_DONAU_DEBUG')
    if debug_env is not None and debug_env.lower() == 'true':
        is_debug = True
        init_log_file_name()

    job_names_env = os.getenv('SQUEUE_NAMES')
    if job_names_env is not None:
        log('get job names from environment: ' + job_names_env)
        job_names = job_names_env.split(',')

    job_states_env = os.getenv('SQUEUE_STATES')
    if job_states_env is not None:
        log('get job states from environment: ' + job_states_env)
        job_states = job_states_env.split(',')

    job_users_env = os.getenv('SQUEUE_USERS')
    if job_users_env is not None:
        log('get job users from environment: ' + job_users_env)
        job_users = job_users_env.split(',')

    job_partitions_env = os.getenv('SQUEUE_PARTITION')
    if job_partitions_env is not None:
        # this message previously said "job users" (copy-paste slip)
        log('get job partitions from environment: ' + job_partitions_env)
        job_partitions = job_partitions_env.split(',')


def get_invalid_option(err_message):
    """Extract the offending option from a getopt error message.

    For an "option X not recognized" message the option text X is returned.
    For an "option X requires argument" message the matching complaint is
    printed through print_single_option() and the script exits with status 1.
    In every other case an empty string is returned.
    """
    def sole_capture(pattern):
        # only an unambiguous (exactly one) match is accepted
        found = re.findall(pattern, err_message)
        return found[0] if len(found) == 1 else None

    if 'not recognized' in err_message:
        unknown_option = sole_capture(r'option (.*) not recognized')
        if unknown_option is not None:
            return unknown_option
    if 'requires argument' in err_message:
        bare_option = sole_capture(r'option (.*) requires argument')
        if bare_option is not None:
            print_single_option(bare_option)
            sys.exit(1)
    return ''


def check_user(user_name):
    """Verify that *user_name* exists in the password database.

    Returns True for a known account. For an unknown one an error line is
    printed and the script exits with status 1, so False is never returned.
    """
    try:
        pwd.getpwnam(user_name)
    except KeyError:
        complaint = 'squeue: error: Invalid user: {}'.format(user_name)
        print(complaint)
        sys.exit(1)
    else:
        return True


def print_invalid_option(option):
    """Print the "invalid option" complaint for *option* plus a help hint.

    Options whose second character is '-' are reported verbatim; anything else
    is reported without its first character. An empty option prints nothing.
    """
    if not option:
        return
    looks_long = len(option) > 1 and option[1] == '-'
    if looks_long:
        complaint = "squeue: invalid option '%s'" % option
    else:
        complaint = "squeue: invalid option -- '%s'" % option[1:]
    print(complaint)
    print('Try "squeue --help" for more information')


def print_single_option(option):
    """Print the "requires an argument" complaint for *option* plus a help hint.

    Nothing is printed when the option is shorter than two characters or does
    not start with a dash.
    """
    if len(option) < 2:
        return
    if option.startswith('--'):
        complaint = "squeue: option '%s' requires an argument" % option
    elif option[0] == '-':
        complaint = "squeue: option requires an argument -- '%s'" % option[1:]
    else:
        return
    print(complaint)
    print('Try "squeue --help" for more information')


def print_invalid_state(state):
    """Print the two-line error shown when *state* is not a supported job state."""
    report = (
        'squeue: error: Invalid job state specified: ' + state,
        'squeue: error: Valid job states include: PENDING,RUNNING,SUSPENDED',
    )
    for report_line in report:
        print(report_line)


def print_title():
    """Print the table header once; later calls are no-ops.

    In long mode the current local time is printed first, followed by the
    nine-column header; otherwise the eight-column short header is printed.
    """
    global has_print_title
    if has_print_title:
        return
    has_print_title = True

    if not is_long:
        short_headers = (DISPLAY_JOB_ID, DISPLAY_PARTITION, DISPLAY_NAME,
                         DISPLAY_USER, DISPLAY_STATE_SHORT, DISPLAY_TIME,
                         DISPLAY_NODES, DISPLAY_NODELIST)
        print(DISPLAY_SHORT_FORMAT.format(*short_headers))
        return

    # long report: timestamp line, then the header with the time-limit column
    print(datetime.now().strftime(DISPLAY_TIME_FORMAT))
    long_headers = (DISPLAY_JOB_ID, DISPLAY_PARTITION, DISPLAY_NAME,
                    DISPLAY_USER, DISPLAY_STATE, DISPLAY_TIME,
                    DISPLAY_TIME_LIMIT, DISPLAY_NODES, DISPLAY_NODELIST)
    print(DISPLAY_LONG_FORMAT.format(*long_headers))


def convert_donau_job_states(slurm_job_states):
    """Translate Slurm state filters into the Donau state names to query.

    Accepted values (case-insensitive) are ``PD``/``PENDING``, ``R``/``RUNNING``
    and ``S``/``SUSPENDED``. A pending filter expands to both ``PENDING`` and
    ``WAITING``. Each Donau state appears at most once, in order of first
    appearance; previously a repeated suspended filter produced ``STOPPED``
    several times because its "seen" flag was never set.

    :param slurm_job_states: iterable of Slurm state strings.
    :return: list of Donau state names (empty for empty input).

    Any other value prints the invalid-state error and exits with status 1.
    """
    pending_states = ('PENDING', 'WAITING')
    slurm_to_donau = {
        'PD': pending_states,
        'PENDING': pending_states,
        'R': ('RUNNING',),
        'RUNNING': ('RUNNING',),
        'S': ('STOPPED',),
        'SUSPENDED': ('STOPPED',),
    }
    donau_job_states = []
    for state in slurm_job_states:
        mapped = slurm_to_donau.get(state.upper())
        if mapped is None:
            print_invalid_state(state)
            sys.exit(1)
        for donau_state in mapped:
            # skip states already requested by an earlier, equivalent filter
            if donau_state not in donau_job_states:
                donau_job_states.append(donau_state)
    return donau_job_states


def parse_job_info(line):
    """Turn one line of the djob listing into a job-info dictionary.

    An empty dictionary is returned when the line has fewer columns than
    JOB_QUERY_IEN, or when a job-name filter is active and the job name is not
    in it. Otherwise the name, id, user, state and queue columns are returned
    under their DISPLAY_* keys.
    """
    columns = line.split()
    if len(columns) < JOB_QUERY_IEN:
        return {}

    # apply the -n/--name filter, if any
    name = columns[JOB_QUERY_NAME_INDEX]
    if len(job_names) > 0 and name not in job_names:
        log(name + ' is not in filter job name list: ' + ' '.join(job_names))
        return {}

    return {
        DISPLAY_NAME: name,
        DISPLAY_ID: columns[JOB_QUERY_ID_INDEX],
        DISPLAY_USER: columns[JOB_QUERY_USER_INDEX],
        DISPLAY_STATE: columns[JOB_QUERY_STATE_INDEX],
        DISPLAY_QUEUE: columns[JOB_QUERY_QUEUE_INDEX],
    }


def check_result_error(out_line):
    """Classify one line of djob output.

    Returns True only for a data line, i.e. one whose first token is all
    digits. Every other line returns False; lines recognised as errors (or not
    recognised at all) are also recorded in ``job_err_info`` for later display.
    The checks run in a fixed order and the first match wins.
    """
    global job_err_info

    # permission problems: keep the message without its first token
    permission_markers = ('access to job', 'only permitted', 'the permission')
    if any(marker in out_line for marker in permission_markers):
        job_err_info.append(' '.join(out_line.split()[1:]))
        return False

    # an invalid queue name is dropped silently
    if 'Job query queue' in out_line and 'invalid' in out_line:
        return False

    if 'error' in out_line and 'token' in out_line:
        job_err_info.append(out_line)
        return False

    if 'not exist' in out_line:
        first_token = out_line.split()[0]
        if first_token.isdigit():
            job_err_info.append('slurm_load_jobs error: Invalid job id %s specified' % first_token)
        else:
            job_err_info.append('slurm_load_jobs error: Invalid job id specified')
        return False

    # separator rows
    if out_line.startswith('---'):
        return False

    tokens = out_line.split()
    leading = tokens[0] if tokens else None
    if leading is not None and leading.isdigit():
        return True
    # the header row is skipped quietly; anything else counts as an error
    if leading != 'JOB_ID':
        job_err_info.append(out_line)
    return False


def is_administrator():
    """Return True when the current user is the configured cluster administrator.

    Runs ``dadmin show config -n cluster.administrators`` and compares the
    third whitespace-separated field of the second output line with the login
    name of the current uid.

    Exits with status 1 when the output cannot be interpreted: a single output
    line is echoed as is (presumably an error message from dadmin), while
    empty output or a too-short second line prints a generic failure message.
    """
    cmd = 'dadmin show config -n cluster.administrators'
    # the with-block closes the pipe as soon as the output has been read
    with os.popen(cmd) as pipe:
        out_lines = pipe.read().splitlines()
    if len(out_lines) == 1:
        print(out_lines[0])
        sys.exit(1)
    if len(out_lines) < 2:
        print('check administrator failed')
        sys.exit(1)
    adm_user_info = out_lines[1].split()
    if len(adm_user_info) < 3:
        print('check administrator failed')
        sys.exit(1)
    adm_user = adm_user_info[2]
    return adm_user == pwd.getpwuid(os.getuid())[0]


def execute_job_query(query_command):
    """Run a djob listing command and return the jobs it reports.

    :param query_command: complete shell command line to execute.
    :return: dict mapping job id (string) to the info dict produced by
        parse_job_info(); empty when the output has at most one line or
        contains ``no matches``.

    Lines rejected by check_result_error() are logged and trigger the table
    header (print_title() prints it only once).
    """
    log("query_command: %s" % query_command)
    job_dict = {}
    # the with-block closes the pipe as soon as the output has been read
    with os.popen(query_command) as pipe:
        output = pipe.read()
    # log before any early return so empty or one-line results are traceable
    log('get output: %s' % output)
    job_info = output.splitlines()
    if len(job_info) <= 1 or 'no matches' in output:
        return job_dict

    # first pass: separate data lines from headers, separators and errors
    valid_info = []
    for info in job_info:
        if check_result_error(info):
            valid_info.append(info)
        else:
            log('check djob result error: %s' % info)
            print_title()

    # second pass: parse the data lines; filtered-out jobs come back empty
    for line in valid_info:
        info = parse_job_info(line)
        if len(info) > 0:
            job_dict[info[DISPLAY_ID]] = info
    return job_dict


def _shell_quote(text):
    """Wrap *text* in single quotes so the shell passes it on as one literal word."""
    return "'" + text.replace("'", "'\"'\"'") + "'"


def get_job_dict():
    """Build the djob listing command(s) from the active filters and run them.

    The command is assembled from the user filter (``-u``), the state filter
    (``-s``), an optional queue (``-q``) and optional job ids. User and
    partition values come from the command line or the environment, so they
    are shell-quoted before being placed on the command line.

    :return: dict mapping job id to job info, merged over all requested
        partitions.

    Exits with status 1 on an invalid state, an invalid job id or (through
    check_user) an unknown user name.
    """
    query_command = djob_path
    # check job_users
    # -u/--user only takes effect while user is administrator
    if is_administrator():
        log('current user is administrator')
        if len(job_users) > 0:
            query_command += ' -u ' + _shell_quote(' '.join(job_users))
        else:
            # an administrator sees the jobs of all users by default
            query_command += ' -u all'
    else:
        log('current user is not administrator')
        # check_user() exits the script on an unknown name, so every name that
        # survives this filter is a valid account
        support_users = [user_name for user_name in job_users if check_user(user_name)]
        if len(job_users) > 0 and len(support_users) == 0:
            return {}
        if len(support_users) > 0:
            query_command += ' -u ' + _shell_quote(' '.join(support_users))

    # filter job_state
    donau_job_states = convert_donau_job_states(job_states)
    if len(donau_job_states) == 0:
        donau_job_states = ['WAITING', 'PENDING', 'RUNNING', 'STOPPED']
    query_command += ' -s ' + '\'' + ' '.join(donau_job_states) + '\''

    # job ids are validated once, before any query is issued
    for job_id in job_ids:
        if not job_id.isdigit():
            print('squeue: error: Invalid job id: %s' % job_id)
            sys.exit(1)
    job_id_suffix = ''
    if len(job_ids) > 0:
        job_id_suffix = ' ' + ' '.join(job_ids)

    if len(job_partitions) == 0:
        return execute_job_query(query_command + job_id_suffix)

    # filter job partition
    # for multi queues/partitions query, if one of queue_names is invalid:
    # donau would return failure and not display jobs in valid queues
    # slurm would filter failure and display jobs in valid queues
    # so each partition is queried separately and the results are merged
    job_dict = {}
    for partition in job_partitions:
        # quoting also keeps an empty partition from swallowing the next word
        tmp_command = query_command + ' -q ' + _shell_quote(partition) + job_id_suffix
        log('query command: %s' % tmp_command)
        job_dict.update(execute_job_query(tmp_command))
    return job_dict


def adjust_display_sec(second):
    """Return *second* as text, prefixed with '0' when it is below ten."""
    return '0%d' % second if second < 10 else str(second)


def get_display_runtime(total_seconds):
    """Format a duration in seconds the way Slurm's squeue prints times.

    The result is ``M:SS`` below one hour, ``H:MM:SS`` below one day and
    ``D-HH:MM:SS`` from one day on. Every field after the leading one is
    zero-padded to two digits; previously minutes (and hours in the day form)
    were not padded, giving output such as ``1:5:03``.

    :param total_seconds: non-negative integer number of seconds.
    :return: the formatted duration string.
    """
    minutes, seconds = divmod(total_seconds, 60)
    hours, minutes = divmod(minutes, 60)
    days, hours = divmod(hours, 24)
    if days:
        return '%d-%02d:%02d:%02d' % (days, hours, minutes, seconds)
    if hours:
        return '%d:%02d:%02d' % (hours, minutes, seconds)
    # in slurm style the minute field is shown even when it is 0
    return '%d:%02d' % (minutes, seconds)


def cnvert_to_slurm_state(state):
    """Map a Donau job state to the Slurm state text to display.

    The long report uses the full Slurm name, the short report the
    abbreviation. States without a Slurm counterpart are shown as '-'.
    """
    # (donau names, short form, long form)
    translations = (
        (('RUNNING',), 'R', 'RUNNING'),
        (('WAITING', 'PENDING'), 'PD', 'PENDING'),
        (('STOPPED',), 'S', 'SUSPENDED'),
    )
    for donau_names, short_form, long_form in translations:
        if state in donau_names:
            return long_form if is_long else short_form
    return '-'


def covert_time_out(time_out):
    """Return the display text for a job time limit given as a string.

    A value that is not all digits, or that equals zero, means no limit and
    yields 'UNLIMITED'; any other value is formatted as a duration.
    """
    if time_out.isdigit() and int(time_out) != 0:
        return get_display_runtime(int(time_out))
    return 'UNLIMITED'


def print_job_info(info):
    """Print one table row for the job described by *info*.

    Partition, name, user, state and run time are cut to their column limits.
    The node columns show the execution nodes and their count when there are
    any, otherwise the state reason code (or '-' when that is missing too).
    The long report additionally shows the time limit.
    """
    def clip(text, limit):
        # keep at most `limit` characters
        return text[:limit] if len(text) > limit else text

    job_id = info.get(DISPLAY_ID)
    partition = clip(info.get(DISPLAY_QUEUE), PARTITION_STR_LEN_LIMIT)
    name = clip(info.get(DISPLAY_NAME), NAME_STR_LEN_LIMIT)
    user = clip(info.get(DISPLAY_USER), USER_STR_LEN_LIMIT)

    raw_state = info.get(DISPLAY_STATE)
    state = '-' if raw_state is None else cnvert_to_slurm_state(raw_state)
    state = clip(state, STATE_STR_LEN_LIMIT)

    run_time = clip(info.get(DISPLAY_RUN_TIME), TIME_STR_LEN_LIMIT)

    node_count = '-'
    node_list = '-'
    exec_nodes = info.get(DISPLAY_EXEC_NODE)
    if len(exec_nodes) > 0:
        node_list = ','.join(exec_nodes)
        node_count = str(len(exec_nodes))
    elif info.get(DISPLAY_STATE_REASON_CODE) is not None:
        # no nodes: the last column carries the reason instead
        node_list = info.get(DISPLAY_STATE_REASON_CODE)

    fields = [job_id, partition, name, user, state, run_time]
    if is_long:
        raw_time_out = info.get(DISPLAY_TIME_OUT)
        if raw_time_out is None:
            fields.append('UNLIMITED')
        else:
            fields.append(covert_time_out(raw_time_out))
        row_format = DISPLAY_LONG_FORMAT
    else:
        row_format = DISPLAY_SHORT_FORMAT
    fields.extend([node_count, node_list])
    print(row_format.format(*fields))


def print_query_job_info(info_dict):
    """Fetch per-task details for the listed jobs and print one row per job.

    Runs ``djob -x -CO "JOB_ID JOB_TIMEOUT EXEC_NODE MAIN_STATE_REASON_CODE
    MRUN_TIME"`` for every job id in *info_dict*, merges the task-level lines
    into the job dictionaries and prints the jobs in ascending numeric id
    order. Jobs for which no detail line was returned are not printed.

    :param info_dict: dict mapping job id (string of digits) to the job info
        dict created by parse_job_info(); it is updated in place.
    """
    query_custom_command = '%s -x -CO "JOB_ID JOB_TIMEOUT EXEC_NODE MAIN_STATE_REASON_CODE MRUN_TIME"' % djob_path
    if info_dict:
        query_custom_command += ' ' + ' '.join(info_dict)
    log('query_custom_command: %s' % query_custom_command)
    # the with-block closes the pipe as soon as the output has been read
    with os.popen(query_custom_command) as pipe:
        output = pipe.read()
    log('query detail result: %s' % output)

    for line in output.splitlines():
        query_info = line.split()
        # the number of fields (not characters) decides whether every index
        # used below exists
        if len(query_info) < JOB_CUSTOM_VALID_LEN:
            continue
        job_id = query_info[JOB_CUSTOM_ID_INDEX]
        if not job_id.isdigit():
            continue
        info = info_dict.get(job_id)
        if info is None:
            # detail line for a job that was not part of the listing
            continue
        # time out is job level, just override
        info[DISPLAY_TIME_OUT] = query_info[JOB_CUSTOM_TIME_OUT_INDEX]
        # state reason code is task level
        state_reason_code = query_info[JOB_CUSTOM_STATE_REASON_INDEX]
        if state_reason_code != '-':
            info[DISPLAY_STATE_REASON_CODE] = state_reason_code
        # exec node is task level: collect each node once, in first-seen order
        exec_nodes = info.get(DISPLAY_EXEC_NODE)
        if exec_nodes is None:
            exec_nodes = info[DISPLAY_EXEC_NODE] = []
        exec_node = query_info[JOB_CUSTOM_EXEC_NODE_INDEX]
        if exec_node != '-' and exec_node not in exec_nodes:
            exec_nodes.append(exec_node)
        # run_time is task level, record the maximum value for all tasks
        run_time = query_info[JOB_CUSTOM_RUN_TIME_INDEX]
        if run_time.isdigit():
            seconds = int(run_time)
            recorded_seconds = info.get(TOTAL_SECONDS)
            if recorded_seconds is None or seconds > recorded_seconds:
                info[TOTAL_SECONDS] = seconds
                info[DISPLAY_RUN_TIME] = get_display_runtime(seconds)
        else:
            # task not allocated yet: default to 0:00 without discarding a
            # run time already recorded from another task
            info.setdefault(DISPLAY_RUN_TIME, '0:00')

    for job_id in sorted(info_dict, key=int):
        info = info_dict[job_id]
        if info.get(DISPLAY_EXEC_NODE) is None:
            continue
        print_job_info(info)


def check_donau_cli_env():
    """Check the Donau CLI environment and locate the djob executable.

    Both CCSCHEDULER_CLI_HOME and CCS_CLI_HOME must be set. ``djob_path`` is
    set to ``<CCS_CLI_HOME>/bin/djob``. A hint is printed and the script exits
    with status 1 when a variable is missing or that path does not exist.
    """
    global djob_path
    scheduler_home = os.getenv('CCSCHEDULER_CLI_HOME')
    cli_home = os.getenv('CCS_CLI_HOME')
    if scheduler_home is None or cli_home is None:
        print("cli env home is not configured, source configure 'profile.env'(bash) or 'cshrc.env'(csh) first")
        sys.exit(1)

    # the global is assigned before the existence check, as before
    djob_path = os.path.join(cli_home, 'bin', 'djob')
    if os.path.exists(djob_path):
        return
    print("djob path is not exist,source configure 'profile.env'(bash) or 'cshrc.env'(csh) first")
    sys.exit(1)


if __name__ == '__main__':
    # Entry point: validate the environment, read options, query and print.
    if os.getuid() == 0:
        print('Permission denied. The root user is not allowed to operate.')
        sys.exit(1)
    check_donau_cli_env()
    init_config_with_env()
    # getopt spec: a letter followed by ':' takes a value. The former spec
    # contained stray '-' characters, which getopt accepted as a flag of its own.
    short_options = 'j:ln:t:u:p:'
    long_options = ['job=', 'help', 'long', 'name=', 'states=', 'user=', 'partition=', 'usage']
    try:
        opts, args = getopt.getopt(sys.argv[1:], short_options, long_options)
    except getopt.GetoptError as err:
        invalid_option = get_invalid_option(str(err))
        if len(invalid_option) != 0:
            print_invalid_option(invalid_option)
        else:
            print(str(err))
        sys.exit(1)

    # options from command have higher priority than environment
    for opt_name, opt_value in opts:
        if opt_name == '--help':
            print(get_help_info())
            sys.exit()
        elif opt_name == '--usage':
            print(get_usage_info())
            sys.exit()
        elif opt_name in ('-j', '--job'):
            log("get job IDs from command options: %s" % opt_value)
            job_ids = opt_value.split(",")
        elif opt_name in ('-n', '--name'):
            log("get job names from command options: %s" % opt_value)
            job_names = opt_value.split(",")
        elif opt_name in ('-l', '--long'):
            is_long = True
        elif opt_name in ('-t', '--states'):
            log("get job states from command options: %s" % opt_value)
            job_states = opt_value.split(",")
        elif opt_name in ('-u', '--user'):
            log("get job users from command options: %s" % opt_value)
            job_users = opt_value.split(",")
        elif opt_name in ('-p', '--partition'):
            log("get job partitions from command options: %s" % opt_value)
            job_partitions = opt_value.split(",")

    # anything left over is a positional argument, which squeue does not accept
    if len(args) > 0:
        log('remain arguments: %s' % ' '.join(args))
        print('squeue: error: Unrecognized option: ' + args[0])
        print(get_usage_info())
        sys.exit(1)

    job_info_dict = get_job_dict()
    # the header is printed even when there are no jobs to list
    print_title()
    if len(job_info_dict) > 0:
        print_query_job_info(job_info_dict)
    # collected error lines go below a separator, after the job table
    if len(job_err_info) > 0:
        print('---------------------------------------------')
    for err_info in job_err_info:
        print(err_info)
